-- drop function if exists sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt(int[], float[], int[]);
-- Gradient of the loss w.r.t. the independent variable of an aggregate slice-sum.
-- Every element of i_dloss_ddepdt is replicated into a contiguous block of
-- i_cnt_per_grp[d] cells along each dimension d, so the result has
-- array_length(i_dloss_ddepdt, d) * i_cnt_per_grp[d] cells in dimension d.
-- NOTE(review): presumably the forward op sums each such block of the independent
-- variable into one cell of the dependent variable, which makes this replication
-- its gradient — confirm against the forward function.
--
-- Parameters:
--   i_indepdt_len   1-D array: length of each dimension of the independent variable.
--   i_dloss_ddepdt  gradient of the loss w.r.t. the dependent variable (1 to 4 dims).
--   i_cnt_per_grp   1-D array: block (group) size per dimension. When it has fewer
--                   entries than i_indepdt_len, the missing leading entries are taken
--                   from i_indepdt_len, i.e. that whole dimension is one group.
--                   A null array makes every dimension a single group.
-- Returns:
--   the gradient w.r.t. the independent variable, or null when i_dloss_ddepdt is null.
-- Raises:
--   when the normalized i_cnt_per_grp does not have 1 to 4 entries; and, when the
--   setting pg4ml._v_is_debug_check = '1', on shape mismatches between the arguments.
create or replace function sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt
(
  i_indepdt_len    int[],
  i_dloss_ddepdt   float[],
  i_cnt_per_grp    int[]
)
returns float[]
as
$$
declare   
  v_dloss_ddepdt_len_y    int         := array_length(i_dloss_ddepdt, 1);
  v_dloss_ddepdt_len_x    int         := array_length(i_dloss_ddepdt, 2);
  v_dloss_ddepdt_len_x3   int         := array_length(i_dloss_ddepdt, 3);
  v_dloss_ddepdt_len_x4   int         := array_length(i_dloss_ddepdt, 4);
  v_ret                   float[]     ;
  v_cur_y                 int         ;
  v_cur_x                 int         ;
  v_cur_x3                int         ;
  v_cur_x4                int         ;
  v_cur_dim               int         ;
begin
  -- Normalize i_cnt_per_grp to one entry per dimension of the independent variable:
  -- leading dimensions it does not cover take their full length from i_indepdt_len.
  -- NOTE(review): sm_sc.fv_coalesce presumably fills null entries of its first
  -- argument from the second — confirm.
  i_cnt_per_grp := 
    sm_sc.fv_coalesce
    (
      i_indepdt_len[ : array_length(i_indepdt_len, 1) - coalesce(array_length(i_cnt_per_grp, 1), 0)] || i_cnt_per_grp
    , i_indepdt_len
    )
  ;

  if current_setting('pg4ml._v_is_debug_check', true) = '1'
  then
    -- Structural checks first, so the per-dimension check below only ever sees
    -- 1-D shape vectors whose length matches the ndims of i_dloss_ddepdt.
    if array_ndims(i_indepdt_len) <> 1
    then 
      raise exception 'unsupport for ndims of i_indepdt_len > 1';
    elsif array_ndims(i_cnt_per_grp) > 1
    then 
      raise exception 'unsupport ndims of i_cnt_per_grp > 1.';
    elsif array_ndims(i_dloss_ddepdt) <> array_length(i_indepdt_len, 1)
    then 
      raise exception 'unmatch between dims of i_indepdt, i_depdt and i_dloss_ddepdt.';
    elsif array_ndims(i_dloss_ddepdt) <> array_length(i_cnt_per_grp, 1) 
    then 
      raise exception 'unmatch between ndims of i_dloss_ddepdt and length of i_cnt_per_grp.';
    end if;

    -- Each dimension of the independent variable must be exactly
    -- (group size) * (gradient length) in that dimension. Only the dimensions
    -- i_dloss_ddepdt really has are compared; a null i_dloss_ddepdt yields an
    -- empty loop. Null entries compare as unknown and are not flagged.
    for v_cur_dim in 1 .. coalesce(array_ndims(i_dloss_ddepdt), 0)
    loop
      if i_indepdt_len[v_cur_dim] <> i_cnt_per_grp[v_cur_dim] * array_length(i_dloss_ddepdt, v_cur_dim)
      then
        raise exception 'unperfect i_indepdt_len, i_cnt_per_grp for i_dloss_ddepdt at some dims';
      end if;
    end loop;
  end if;
    
  if i_dloss_ddepdt is null
  then 
    return null;
    
  elsif array_length(i_cnt_per_grp, 1) = 1
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y] *` i_cnt_per_grp);
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      -- broadcast one gradient cell over its whole group
      v_ret
        [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
      := 
        array_fill
        (
          i_dloss_ddepdt
            [v_cur_y]
        , i_cnt_per_grp
        ) 
      ;        
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 2
  then  
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        v_ret
          [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
          [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
        :=
          array_fill
          (
            i_dloss_ddepdt
              [v_cur_y]
              [v_cur_x]
          , i_cnt_per_grp
          ) 
        ;        
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 3
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          v_ret
            [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
            [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
            [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
          :=
            array_fill
            (
              i_dloss_ddepdt
                [v_cur_y]
                [v_cur_x]
                [v_cur_x3]
            , i_cnt_per_grp
            ) 
          ;        
        end loop;
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 4
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3, v_dloss_ddepdt_len_x4] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          for v_cur_x4 in 1 .. v_dloss_ddepdt_len_x4
          loop 
            v_ret
              [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
              [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
              [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
              [(v_cur_x4 - 1) * i_cnt_per_grp[4] + 1 : v_cur_x4 * i_cnt_per_grp[4]]
            :=
              array_fill
              (
                i_dloss_ddepdt
                  [v_cur_y]
                  [v_cur_x]
                  [v_cur_x3]
                  [v_cur_x4]
              , i_cnt_per_grp
              ) 
            ;        
          end loop;
        end loop;
      end loop;
    end loop;

  else
    -- Without this branch an unsupported rank silently produced a null gradient.
    raise exception 'unsupport length of i_cnt_per_grp: %; only 1 to 4 dims are supported.', array_length(i_cnt_per_grp, 1);
    
  end if;
  
  return v_ret;
end
$$
language plpgsql stable
parallel safe
cost 100;

-- select 
--   sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt
--   (
--     array[12]
--   , array[2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--   , array[2]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt
--   (
--     array[6, 12]
--   , array
--     [
--       [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--     , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--     ]
--   , array[3, 2]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt
--   (
--     array[4, 4, 24]
--   , array
--     [
--       [
--         [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       ]
--     , [
--         [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       ]
--     ]
--   , array[2, 2, 4]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_sum_dloss_dindepdt
--   (
--     array[4, 4, 8, 18]
--   , array
--     [
--       [
--         [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       , [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       ]
--     , [
--         [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       , [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       ]
--     ]
--   , array[2, 2, 4, 3]
--   ) :: decimal[] ~=` 3